import time
import os
import re
import pickle
from functools import wraps
from unidecode import unidecode
from datetime import datetime, timedelta, date
import numpy as np
import matplotlib.pyplot as plt
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from scipy.sparse import vstack
import polars as pl


def open_pickle(path):
    """Load and return the object stored in the pickle file at ``path``.

    NOTE: unpickling can execute arbitrary code, so only use this on files
    that come from a trusted source.
    """
    with open(path, 'rb') as pickle_file:
        loaded = pickle.load(pickle_file)
    return loaded


def write_pickle(obj, path):
    """Serialize ``obj`` into the file at ``path`` (overwriting it).

    The highest pickle protocol supported by the running interpreter is used.
    """
    protocol = pickle.HIGHEST_PROTOCOL
    with open(path, 'wb') as pickle_file:
        pickle.dump(obj, pickle_file, protocol=protocol)
     

def timeit(func):
    """Decorator that prints how long each call to ``func`` takes.

    The wrapped function's return value is passed through unchanged. The
    elapsed wall-clock time (measured with ``time.perf_counter``) is printed
    only when ``func`` returns normally.
    """
    @wraps(func)
    def timeit_wrapper(*args, **kwargs):
        started = time.perf_counter()
        outcome = func(*args, **kwargs)
        total_time = time.perf_counter() - started
        print(f'Function {func.__name__} took {total_time:.4f} seconds')
        return outcome

    return timeit_wrapper


def create_time_series(
        df, first_date_col='first_date', last_date_col='last_date', value_col=None):
    """Build a daily aggregate time series from rows describing date intervals.

    Every row of ``df`` is counted on each day between its ``first_date_col``
    and ``last_date_col`` (both ends included). For every day from the
    smallest first date to the largest last date, the values of the rows
    covering that day are summed.

    ``value_col`` names the column holding the per-row count; use it when a
    row stands for more than one unit. When it is falsy, each row counts as 1
    and the output value column is called 'vacancies'.

    Returns a DataFrame with a 'date' column and one column named after
    ``value_col``.
    """
    if not value_col:
        value_col = 'vacancies'
        df = df.with_columns(pl.lit(1).alias(value_col))

    interval_cols = [first_date_col, last_date_col]
    days = pl.date_range(
        df[first_date_col].min(), df[last_date_col].max(), interval='1d', eager=True
    )

    # One expression per calendar day: the row's value when the day falls
    # inside the row's interval, 0 otherwise (boolean mask times value).
    daily_exprs = []
    for day in days:
        covers_day = (pl.col(first_date_col) <= day) & (day <= pl.col(last_date_col))
        daily_exprs.append(
            (covers_day * pl.col(value_col)).alias(day.strftime("%Y-%m-%d"))
        )

    # Collapse identical intervals first so each distinct interval is expanded once.
    per_interval = (
        df.select(interval_cols + [value_col])
        .groupby(interval_cols)
        .agg(pl.sum(value_col))
    )
    wide = per_interval.with_columns(daily_exprs).drop(interval_cols + [value_col])

    # Summing each day column gives the daily total; transposing turns the
    # single row of totals into one row per day ('column' holds the day label).
    totals = wide.sum().transpose(include_header=True)
    return totals.select(
        [
            pl.col('column').str.strptime(pl.Date, '%Y-%m-%d').alias('date'),
            pl.col('column_0').alias(value_col)
        ]
    )


# Select matplotlib's 'tableau-colorblind10' style sheet. This runs once when
# the module is imported and changes matplotlib's global style state.
plt.style.use('tableau-colorblind10')

def plot_data(data, start_date=datetime(2018, 1, 1), end_date=None, ylim=None):
    """Plot the 'vacancies' column of several series on one 12x8 figure.

    Args:
        data: mapping from legend label to a polars DataFrame with 'date' and
            'vacancies' columns.
        start_date: first date to include (inclusive).
        end_date: last date to include (inclusive). ``None`` (the default)
            means "now", resolved on every call rather than once at import.
        ylim: optional y-axis limits, passed to ``Axes.set_ylim`` when truthy.

    Returns:
        None. The figure is left as the current matplotlib figure.
    """
    if end_date is None:
        # A `datetime.now()` default in the signature would be frozen at
        # import time and go stale in long-running sessions.
        end_date = datetime.now()

    _fig, ax = plt.subplots(figsize=(12, 8))
    for label, series in data.items():
        df = (
            series
            .filter(pl.col('date').is_between(start_date, end_date, closed='both'))
            .sort(by='date')
        )
        ax.plot(
            df['date'].to_list(), df['vacancies'].to_list(), label=label,
            linewidth=2.5)
    if ylim:
        ax.set_ylim(ylim)
    ax.legend()
    ax.grid()
    ax.set_xlabel('Date')
    ax.set_ylabel('Job postings')


def compute_correlation(
    data, series0, series1, start_date=datetime(2018, 1, 1),
    end_date=None
    ):
    """Print the correlation between two 'vacancies' series over a date range.

    The two DataFrames are left-joined on 'date' (keeping the dates of
    ``series0``), restricted to ``start_date``..``end_date``, and rows with a
    missing value on either side are dropped before the correlation is
    computed. The result is rounded to 3 decimals and printed.

    Args:
        data: mapping from series name to a polars DataFrame with 'date' and
            'vacancies' columns.
        series0: key of the first series in ``data``.
        series1: key of the second series in ``data``.
        start_date: start of the date range; must be a ``datetime`` because
            its ``.date()`` is used in the printed message.
        end_date: end of the date range. ``None`` (the default) means "now",
            resolved on every call rather than once at import.

    Returns:
        None.
    """
    if end_date is None:
        # A `datetime.now()` default in the signature would be frozen at
        # import time and go stale in long-running sessions.
        end_date = datetime.now()

    corr = np.round(
        data[series0]
        .join(data[series1], how='left', on='date')
        .filter(pl.col('date').is_between(start_date, end_date))
        .drop('date')
        # after the join the right-hand 'vacancies' column is 'vacancies_right'
        .rename({'vacancies': series0, 'vacancies_right': series1})
        .drop_nulls()
        .corr()
        [0, 1],  # off-diagonal entry of the 2x2 correlation matrix
        3
    )
    print(f'correlation({series0}, {series1}) since {str(start_date.date())}: {corr}')


def clean_city_name(df, city_name='city'):
    """Return ``df`` with the city column normalized.

    Steps applied to the column, in order: remove accents with ``unidecode``,
    lowercase, replace '-' and '/' with a space, drop remaining characters
    that are neither word characters nor whitespace, expand the standalone
    token 'st' to 'saint', and strip surrounding whitespace.

    Args:
        df: polars DataFrame containing the column ``city_name``.
        city_name: name of the column to clean.

    Returns:
        A polars DataFrame with the cleaned column.

    Raises:
        TypeError: if ``df`` is not a polars DataFrame.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(df, pl.DataFrame):
        raise TypeError(
            'df must be a polars DataFrame; use clean_city_name_pandas for pd.DataFrame')
    return df.with_columns(
        pl.col(city_name)
        .apply(unidecode)  # remove accents
        .str.to_lowercase()
        .str.replace_all(r'[-/]', ' ')  # replace - and / by space
        .str.replace_all(r'[^\w\s]', '')  # remove remaining punctuation/symbols
        .str.replace_all(r'\bst\b', 'saint')  # replace st by saint
        .str.strip()
    )

# Lowercase tokens that replace_stopwords removes from company names by default.
stopwords = [
    'inc', 'incorporated', 'group', 'llc', 'ltd', 'limited', 'corp', 'corporation',
    'company', 'co', 'groupe', 'canada']


def replace_stopwords(name, stopwords=stopwords):
    """Return ``name`` with every whitespace-separated stopword token removed.

    Tokens are compared in lowercase against ``stopwords`` (so the entries of
    ``stopwords`` are expected to be lowercase); surviving tokens keep their
    original case and are re-joined with single spaces.

    Args:
        name: string to filter.
        stopwords: iterable of tokens to drop; defaults to the module-level list.

    Returns:
        The filtered string ('' if every token was a stopword).
    """
    # Build the lookup set once rather than once per token.
    excluded = set(stopwords)
    return ' '.join(
        token for token in name.split() if token.lower() not in excluded
    )


def clean_company_name(df, company_name='company_name'):
    """Return ``df`` with the company-name column normalized.

    Steps applied to the column, in order: remove accents with ``unidecode``,
    lowercase, replace '-' and '/' with a space, drop remaining characters
    that are neither word characters nor whitespace, strip surrounding
    whitespace, and remove stopword tokens with ``replace_stopwords``.

    Args:
        df: polars DataFrame containing the column ``company_name``.
        company_name: name of the column to clean.

    Returns:
        A polars DataFrame with the cleaned column.

    Raises:
        TypeError: if ``df`` is not a polars DataFrame.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(df, pl.DataFrame):
        raise TypeError(
            'df must be a polars DataFrame; use clean_company_name_pandas for pd.DataFrame')
    return df.with_columns(
        pl.col(company_name)
        .apply(unidecode)  # remove accents
        .str.to_lowercase()
        .str.replace_all(r'[-/]', ' ')  # replace - and / by space
        .str.replace_all(r'[^\w\s]', '')  # remove remaining punctuation/symbols
        .str.strip()
        .apply(replace_stopwords)
    )


def clean_text(series):
    """Strip accents and other non-ASCII characters from a pandas string Series.

    The text is decomposed with Unicode NFKD normalization, encoded to ASCII
    while ignoring characters that cannot be encoded, and decoded back to str.
    """
    decomposed = series.str.normalize('NFKD')
    ascii_bytes = decomposed.str.encode('ascii', errors='ignore')
    return ascii_bytes.str.decode('utf-8')


def clean_city_name_pandas(df, city='city'):
    """Normalize the ``city`` column of a pandas DataFrame in place.

    Steps, in order: lowercase, replace '-' with a space, drop remaining
    characters that are neither word characters nor whitespace, expand the
    standalone token 'st' to 'saint', strip surrounding whitespace, and
    finally remove accents / non-ASCII characters with ``clean_text``.

    The DataFrame is modified in place; nothing is returned.

    NOTE(review): unlike ``clean_city_name``, only '-' is turned into a space
    here; '/' is removed by the punctuation step instead — confirm whether
    that difference is intended.
    """
    lowered = df[city].str.lower()
    spaced = lowered.str.replace(r'-', ' ', regex=True)
    no_punct = spaced.str.replace(r'[^\w\s]','', regex=True)
    expanded = no_punct.str.replace(r'\bst\b', 'saint', regex=True)
    df[city] = clean_text(expanded.str.strip())


def clean_company_name_pandas(df, company_name='company_name'):
    """Normalize the company-name column of a pandas DataFrame in place.

    Steps, in order: lowercase, replace '-' and '/' with a space, drop
    remaining characters that are neither word characters nor whitespace,
    remove stopword tokens with ``replace_stopwords``, strip surrounding
    whitespace, and finally remove accents / non-ASCII characters with
    ``clean_text``.

    Args:
        df: pandas DataFrame; modified in place.
        company_name: name of the column to clean. The cleaned values are
            written back to this same column.

    Returns:
        None.
    """
    # Write back to the column that was read, not a hard-coded 'company_name'.
    df[company_name] = clean_text(
        df[company_name]
        .str.lower()
        .str.replace(r'[-/]', ' ', regex=True)
        .str.replace(r'[^\w\s]','', regex=True)
        .apply(replace_stopwords)
        .str.strip()
    )

class FileNames:
    """Manage a CSV-backed table of data filenames and their dates.

    Used for duration and inflows files. The table (``self.df``) is a polars
    DataFrame with a 'filename' column and a 'date' column, loaded from
    ``path_df`` on construction.
    """
    def __init__(
        self, path_df, new_filenames_format=r'CA_job_postings_202\d{5}.csv',
        path_new_filenames='/home/asd/ae_data_onboarding/prod/indeed/data/'
    ):
        """Store the configuration and load the table.

        Args:
            path_df: path of the CSV file holding the filename table.
            new_filenames_format: regex that candidate filenames must match
                from their start (``re.match``).
            path_new_filenames: directory scanned for new files.
        """
        self.path_df = path_df
        self.new_filenames_format = new_filenames_format
        self.path_new_filenames = path_new_filenames
        self.load_df()

    def load_df(self):
        """Read the table from ``self.path_df`` (string column, then date column)."""
        self.df = pl.read_csv(self.path_df, dtypes=[pl.Utf8, pl.Date])

    def get_new_filenames(self):
        """Return names in the data directory that match the format and are not in the table."""
        # Build the set of known names once instead of scanning the column
        # for every directory entry.
        known = set(self.df['filename'])
        return [
            name for name in os.listdir(self.path_new_filenames)
            if re.match(self.new_filenames_format, name)
            and name not in known
        ]

    def parse_dates(self, filenames):
        """Return the date embedded in each filename, or None when none is found.

        'YYYYMMDD' (starting with '202') is tried first, then 'YYYY_MM_DD'.
        """
        pattern0 = r'(202\d{5})'
        pattern1 = r'(202\d{1}_\d{2}_\d{2})'

        def parse_date_string(string):
            if match := re.search(pattern0, string):
                return datetime.strptime(match.group(0), '%Y%m%d').date()
            if match := re.search(pattern1, string):
                return datetime.strptime(match.group(0), '%Y_%m_%d').date()
            return None

        return [parse_date_string(filename) for filename in filenames]

    def sort_df_by_date(self):
        """Sort the table by 'date', most recent first."""
        self.df = self.df.sort('date', descending=True)

    def write_df(self):
        """Sort the table and write it back to ``self.path_df``."""
        self.sort_df_by_date()
        self.df.write_csv(self.path_df)

    def update_filenames(self, since_last_days=60):
        """Append new files dated within the last ``since_last_days`` days and save.

        Does nothing when the data directory contains no new matching file or
        when none of the new files is recent enough.
        """
        filenames = self.get_new_filenames()
        if not filenames:
            # Nothing new: avoid building a DataFrame from two empty lists,
            # whose 'date' column would have no usable dtype for the filter.
            return
        dates = self.parse_dates(filenames)
        cutoff = date.today() - timedelta(days=since_last_days)
        new_df = (
            pl.DataFrame({'filename': filenames, 'date': dates})
            .filter(pl.col('date') > cutoff)
        )
        if len(new_df) > 0:
            print(f'added {new_df}')
            self.df = self.df.extend(new_df)
            self.write_df()  # write_df sorts before writing

    def to_list(self):
        """Return the 'filename' column as a Python list."""
        return self.df['filename'].to_list()

    def __call__(self):
        """Return the underlying DataFrame."""
        return self.df
    

class FileNames1:
    """Manage a CSV-backed table of data filenames when several vintages of the table exist.

    Used for duration and inflows files. Instead of a fixed CSV path, the
    most recently dated ``df_{file_type}_filenames*.csv`` file found in
    ``path_df_dir`` is loaded on construction and is the file later updates
    are written back to.
    """
    def __init__(
        self,
        path_df_dir,  # directory holding the dated table files
        file_type="duration",  # 'duration' (default) or 'inflows'
        new_filenames_format=r'CA_job_postings_202\d{5}.csv',
        path_new_filenames='/home/asd/ae_data_onboarding/prod/indeed/data/'
    ):
        """Store the configuration, locate the latest table file and load it.

        Args:
            path_df_dir: directory searched for the table CSV files.
            file_type: 'duration' or 'inflows'; selects the filename prefix.
            new_filenames_format: regex that candidate filenames must match
                from their start (``re.match``).
            path_new_filenames: directory scanned for new files.

        Raises:
            ValueError: if ``file_type`` is not 'duration' or 'inflows'.
            FileNotFoundError: if no suitable dated table file is found.
        """
        self.path_df_dir = path_df_dir
        self.file_type = file_type
        self.new_filenames_format = new_filenames_format
        self.path_new_filenames = path_new_filenames
        self.file_prefix = self.determine_file_prefix(file_type)
        self.path_df = self.find_latest_df_file()
        self.load_df()

    def determine_file_prefix(self, file_type):
        """Return the table filename prefix for ``file_type`` ('inflows' or 'duration')."""
        if file_type == "inflows":
            return "df_inflows_filenames"
        elif file_type == "duration":
            return "df_duration_filenames"
        else:
            raise ValueError(f"Invalid file_type: {file_type}. Must be 'duration' or 'inflows'.")

    @staticmethod
    def _extract_file_date(file):
        """Return the date embedded in ``file`` (YYYYMMDD first, then YYYY-MM-DD), or None."""
        candidates = ((r'(\d{8})', '%Y%m%d'), (r'(\d{4}-\d{2}-\d{2})', '%Y-%m-%d'))
        for pattern, fmt in candidates:
            match = re.search(pattern, file)
            if not match:
                continue
            try:
                return datetime.strptime(match.group(1), fmt).date()
            except ValueError:
                # Digits that look like a date but are not a valid calendar
                # date: treat as unrecognized instead of crashing.
                continue
        return None

    def find_latest_df_file(self):
        """Return the path of the most recently dated table file in ``self.path_df_dir``.

        Candidates are files starting with ``self.file_prefix`` and ending
        with '.csv'; files whose name contains no parsable date are ignored.

        Raises:
            FileNotFoundError: if no candidate exists, or none has a date.
        """
        all_files = os.listdir(self.path_df_dir)
        print("Files in directory:", all_files)

        files = [f for f in all_files if f.startswith(self.file_prefix) and f.endswith(".csv")]
        print("Filtered files:", files)

        if not files:
            raise FileNotFoundError(f"No {self.file_prefix}_yyyymmdd.csv files found in the specified directory")

        # Keep only the files for which a date could be extracted.
        dated_files = [
            (file_date, file)
            for file_date, file in ((self._extract_file_date(f), f) for f in files)
            if file_date is not None
        ]
        print("Files with parsed dates:", dated_files)

        if not dated_files:
            raise FileNotFoundError(f"No valid dated files found in {self.path_df_dir}")

        # Latest date wins; on ties the first one listed is kept.
        latest_file = max(dated_files, key=lambda item: item[0])[1]
        print(f"Latest file selected: {latest_file}")
        return os.path.join(self.path_df_dir, latest_file)

    def load_df(self):
        """Read the table from ``self.path_df`` (string column, then date column)."""
        self.df = pl.read_csv(self.path_df, dtypes=[pl.Utf8, pl.Date])

    def get_new_filenames(self):
        """Return names in the data directory that match the format and are not in the table."""
        # Build the set of known names once instead of scanning the column
        # for every directory entry.
        known = set(self.df['filename'])
        return [
            name for name in os.listdir(self.path_new_filenames)
            if re.match(self.new_filenames_format, name)
            and name not in known
        ]

    def parse_dates(self, filenames):
        """Return the date embedded in each filename, or None when none is found.

        'YYYYMMDD' (starting with '202') is tried first, then 'YYYY_MM_DD'.
        """
        pattern0 = r'(202\d{5})'
        pattern1 = r'(202\d{1}_\d{2}_\d{2})'

        def parse_date_string(string):
            if match := re.search(pattern0, string):
                return datetime.strptime(match.group(0), '%Y%m%d').date()
            if match := re.search(pattern1, string):
                return datetime.strptime(match.group(0), '%Y_%m_%d').date()
            return None

        return [parse_date_string(filename) for filename in filenames]

    def sort_df_by_date(self):
        """Sort the table by 'date', most recent first."""
        self.df = self.df.sort('date', descending=True)

    def write_df(self):
        """Sort the table and write it back to ``self.path_df``."""
        self.sort_df_by_date()
        self.df.write_csv(self.path_df)

    def update_filenames(self, since_last_days=60):
        """Append new files dated within the last ``since_last_days`` days and save.

        Does nothing when the data directory contains no new matching file or
        when none of the new files is recent enough.
        """
        filenames = self.get_new_filenames()
        if not filenames:
            # Nothing new: avoid building a DataFrame from two empty lists,
            # whose 'date' column would have no usable dtype for the filter.
            return
        dates = self.parse_dates(filenames)
        cutoff = date.today() - timedelta(days=since_last_days)
        new_df = (
            pl.DataFrame({'filename': filenames, 'date': dates})
            .filter(pl.col('date') > cutoff)
        )
        if len(new_df) > 0:
            print(f'added {new_df}')
            self.df = self.df.extend(new_df)
            self.write_df()  # write_df sorts before writing

    def to_list(self):
        """Return the 'filename' column as a Python list."""
        return self.df['filename'].to_list()

    def __call__(self):
        """Return the underlying DataFrame."""
        return self.df